data = {
    'input_ids': [1, 2, 3, 0],
    'attention_mask': [1, 1, 1, 0],
    'labels': [-100, 2, 0, -100],
}

print(data)
labels = data.pop('labels')
print(labels)
print(data)
